import sys
import torch
import argparse
#import time
import ot

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

sys.path.append("../lib")

from sw_sphere import sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif
from utils_sphere import *
from utils_vmf import *


# Command-line interface: a single integer option, --ntry (default 20),
# read further down as the number of repetitions of each configuration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--ntry",
    type=int,
    default=20,
    help="number of restart",
)
args = parser.parse_args()

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"




if __name__ == "__main__":
    # Experiment: for each dimension d and sample size n, draw two independent
    # sets of n L2-normalized Gaussian vectors and record
    #   * the value returned by sliced_wasserstein_sphere, and
    #   * the optimal-transport cost from ot.emd2 with uniform weights and the
    #     arccos of the pairwise inner products as ground cost.
    # Every configuration is repeated n_try times.
    n_try = args.ntry

    ds = [3, 10, 20, 50, 100]
    n_samples = [10, 100, 1000, 10000]
    n_projs = 100

    # Result arrays, indexed as [dimension index, sample-size index, repetition].
    L_ssw = np.zeros((len(ds), len(n_samples), n_try))
    L_w = np.zeros((len(ds), len(n_samples), n_try))

    for j, d in enumerate(ds):
        # Progress indicator: print the dimension currently being processed.
        print(d, flush=True)

        for k in range(n_try):
            for i, n in enumerate(n_samples):
                # Two independent point sets with unit-norm rows.
                x0 = F.normalize(torch.randn((n, d), device=device), dim=-1, p=2)
                x1 = F.normalize(torch.randn((n, d), device=device), dim=-1, p=2)

                sw = sliced_wasserstein_sphere(x0, x1, n_projs, device)
                L_ssw[j, i, k] = sw.item()

                # Pairwise cost matrix: arccos of the inner products. The clamp
                # keeps every value strictly inside (-1, 1) before arccos.
                ip = x0 @ x1.T
                M = torch.arccos(torch.clamp(ip, min=-1 + 1e-5, max=1 - 1e-5))
                # Uniform weights on both point sets.
                a = torch.ones(x0.shape[0], device=device) / x0.shape[0]
                b = torch.ones(x1.shape[0], device=device) / x1.shape[0]
                w = ot.emd2(a, b, M).item()
                L_w[j, i, k] = w

    # One comma-delimited file per dimension and per quantity; each file holds a
    # (len(n_samples), n_try) matrix: rows are sample sizes, columns repetitions.
    for j, d in enumerate(ds):
        # Fix: the original saved the undefined name `L` (NameError after the
        # whole computation had run); the sliced results live in L_ssw.
        np.savetxt(f"./ssw_sample_d{d}", L_ssw[j], delimiter=",")
        np.savetxt(f"./w_sample_d{d}", L_w[j], delimiter=",")
        
        
